import torch
import torch.nn as nn
import scipy.spatial as ss
import numpy as np
import torch.nn.utils.parametrize as parametrize

class SpatialWeight(nn.Module):
    """Map a free matrix ``W`` to a signed, distance-limited weight matrix.

    ``observable_size`` units are placed at random positions in the unit
    square. ``forward(W)`` returns::

        scale * (-1)**inhib[None, :] * exp(W - Delta) * mask

    where ``Delta`` is the pairwise distance divided by ``ell``, ``mask``
    keeps only off-diagonal pairs closer than ``5 * ell``, and ``inhib``
    is a fixed 0/1 flag per column (columns flagged 1 are negated).

    The class has the ``forward(tensor) -> tensor`` shape expected by
    ``torch.nn.utils.parametrize.register_parametrization``.

    Args:
        input_size: Accepted for signature compatibility; not used here.
        CN: Accepted for signature compatibility; not used here.
        controller_size: Accepted for signature compatibility; not used here.
        observable_size: Number of units, i.e. the side of the square matrix.
        num_layers: Accepted for signature compatibility; not used here.
        N_CO: Accepted for signature compatibility; not used here.
        ell: Length scale dividing the distances; also sets the connection
            cut-off radius (``5 * ell``).
    """

    def __init__(self, input_size, CN = 10, controller_size = 128, observable_size = 64, num_layers=1, N_CO = 5, ell = 0.1):
        super(SpatialWeight, self).__init__()

        # Fixed seed so every instance gets the same layout and sign pattern.
        # NOTE: this reseeds NumPy's *global* generator as a side effect.
        np.random.seed(1)
        self.pos = np.random.random([observable_size, 2])

        # Pairwise Euclidean distances between units; basis of the synaptic matrix.
        self.delpoints = ss.distance.cdist(self.pos, self.pos)
        self.ell = ell
        self.scale = 1

        # Each unit is flagged 0 or 1 with equal probability. The previous
        # call passed a probability list positionally, where NumPy reads it
        # as ``replace`` rather than ``p``; the draw was therefore always
        # uniform. It is kept uniform (and on the same random stream) so the
        # regenerated sign pattern still matches previously trained models.
        inhib_flags = np.random.choice([0, 1], observable_size)

        # Registered as a buffer so it follows the module through
        # ``.to(device)`` / ``.double()``; non-persistent so the keys of
        # ``state_dict()`` stay exactly ``Delta`` and ``mask``.
        self.register_buffer("inhib", torch.tensor(inhib_flags).float(), persistent=False)

        # Relative distances and connectivity mask are fixed (no gradient).
        self.Delta = nn.Parameter(torch.tensor(self.delpoints / self.ell).float(), requires_grad=False)
        within_range = self.delpoints < 5 * self.ell
        off_diagonal = np.eye(observable_size) == 0
        self.mask = nn.Parameter(torch.tensor(np.logical_and(within_range, off_diagonal)).float(), requires_grad=False)

    def forward(self, W):
        """Return the signed, masked weight matrix built from ``W``.

        Args:
            W: Tensor broadcastable against the
                ``(observable_size, observable_size)`` distance matrix.

        Returns:
            Tensor whose entry ``[i, j]`` is
            ``scale * (-1)**inhib[j] * exp(W[i, j] - Delta[i, j]) * mask[i, j]``.
        """
        # inhib[None, :] broadcasts over rows, so the sign is set per column.
        column_sign = (-1) ** self.inhib[None, :]
        return self.scale * column_sign * torch.exp(W - self.Delta) * self.mask

class CommunicationRNN(nn.Module):
    """Controller RNN driving an observed RNN through a low-dimensional channel.

    Data flow in ``forward``::

        input -> controller_rnn -> controller_to_communication
              -> observed_rnn -> BCIdecoder

    The recurrent weight ``weight_hh_l0`` of ``observed_rnn`` is
    parametrized by a ``SpatialWeight`` module built with this model's
    ``observable_size`` and ``ell``.

    Args:
        input_size: Feature size of the input sequence.
        CN: Index of the observed unit that the frozen decoder reads out.
        controller_size: Hidden size of the controller RNN.
        observable_size: Hidden size of the observed RNN.
        num_layers: Number of layers for both RNNs (only layer 0 of the
            observed RNN is parametrized).
        N_CO: Dimension of the communication space between the two RNNs.
        ell: Length scale forwarded to ``SpatialWeight``.
    """

    def __init__(self, input_size, CN = 10, controller_size = 128, observable_size = 64, num_layers=1, N_CO = 5, ell = 0.1):
        super(CommunicationRNN, self).__init__()

        self.dt = 0.01
        self.spatialNet = SpatialWeight(input_size, CN = CN, controller_size = controller_size, observable_size = observable_size, num_layers=num_layers, N_CO = N_CO, ell = ell)

        self.controller_rnn = nn.RNN(input_size, controller_size, num_layers, batch_first=True)
        self.observed_rnn = nn.RNN(N_CO, observable_size, num_layers, batch_first=True)

        # Use the SpatialWeight configured above. Building a second one from
        # defaults ignored ``observable_size`` and ``ell``: any size other
        # than 64 produced a shape mismatch and ``ell`` had no effect.
        parametrize.register_parametrization(self.observed_rnn, "weight_hh_l0", self.spatialNet)

        self.controller_to_communication = nn.Linear(controller_size, N_CO)
        # Created (and therefore part of the state dict) but not used in forward().
        self.communication_to_observed = nn.Linear(N_CO, observable_size)

        # Frozen decoder for the 'condition' neuron: a one-hot weight row
        # picks unit ``CN`` of the observed RNN. The bias keeps its default
        # initialisation and is frozen together with the weight.
        self.CN = CN
        one_hot = torch.zeros([1, observable_size])
        one_hot[:, self.CN] = 1
        self.BCIdecoder = nn.Linear(observable_size, 1)
        self.BCIdecoder.weight.data.copy_(one_hot)
        for param in self.BCIdecoder.parameters():
            param.requires_grad = False

    def forward(self, input_sequence):
        """Run the controller, project to the communication space, run the observed RNN.

        Args:
            input_sequence: Input for ``controller_rnn`` (constructed with
                ``batch_first=True``).

        Returns:
            Tuple ``(output, observed_output)``: the decoder read-out with
            all size-1 dimensions removed, and the full output sequence of
            the observed RNN.
        """
        # Controller RNN
        controller_output, _ = self.controller_rnn(input_sequence)

        # Communication space
        communication_space = self.controller_to_communication(controller_output)

        # Observed RNN
        observed_output, _ = self.observed_rnn(communication_space)

        # Frozen read-out of the 'condition' neuron (unit CN).
        output = self.BCIdecoder(observed_output)

        # squeeze() without an argument drops every size-1 dimension, so a
        # batch of one or a single time step is collapsed as well.
        return output.squeeze(), observed_output
